{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PdpySQpViE_U"
   },
   "source": [
    "Before we can browse the rest of the notebook, we need to install the dependencies: this example uses `datasets` and `transformers`. To use TPUs on colab, we need to install `torch_xla`, and the last line installs `accelerate` from source since the features we are using are very recent and not released yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cJrsKX5AsnHp",
    "outputId": "93ea91d5-7c1b-4531-c75a-5c1f4cb4b43d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (1.6.1)\n",
      "Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.5.1)\n",
      "Requirement already satisfied: huggingface-hub<0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.0.8)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (20.9)\n",
      "Requirement already satisfied: pyarrow>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.1.5)\n",
      "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from datasets) (3.10.1)\n",
      "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
      "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.41.1)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.7/dist-packages (from datasets) (2021.4.0)\n",
      "Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.3)\n",
      "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (2.0.2)\n",
      "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.11.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.19.5)\n",
      "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.45)\n",
      "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.2)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (2.4.7)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.1)\n",
      "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2018.9)\n",
      "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->datasets) (3.7.4.3)\n",
      "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->datasets) (3.4.1)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2020.12.5)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n",
      "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n",
      "Requirement already satisfied: cloud-tpu-client==0.10 in /usr/local/lib/python3.7/dist-packages (0.10)\n",
      "Requirement already satisfied: torch-xla==1.8 from https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl in /usr/local/lib/python3.7/dist-packages (1.8)\n",
      "Requirement already satisfied: oauth2client in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (4.1.3)\n",
      "Requirement already satisfied: google-api-python-client==1.8.0 in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (1.8.0)\n",
      "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.7/dist-packages (from oauth2client->cloud-tpu-client==0.10) (0.2.8)\n",
      "Requirement already satisfied: six>=1.6.1 in /usr/local/lib/python3.7/dist-packages (from oauth2client->cloud-tpu-client==0.10) (1.15.0)\n",
      "Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.7/dist-packages (from oauth2client->cloud-tpu-client==0.10) (0.17.4)\n",
      "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.7/dist-packages (from oauth2client->cloud-tpu-client==0.10) (0.4.8)\n",
      "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from oauth2client->cloud-tpu-client==0.10) (4.7.2)\n",
      "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.0.4)\n",
      "Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.26.3)\n",
      "Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.28.1)\n",
      "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.1)\n",
      "Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2018.9)\n",
      "Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.12.4)\n",
      "Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (20.9)\n",
      "Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (56.0.0)\n",
      "Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.23.0)\n",
      "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.53.0)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.2.1)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.4.7)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.10)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.24.3)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2020.12.5)\n",
      "Collecting git+https://github.com/huggingface/accelerate\n",
      "  Cloning https://github.com/huggingface/accelerate to /tmp/pip-req-build-yqwt45eh\n",
      "  Running command git clone -q https://github.com/huggingface/accelerate /tmp/pip-req-build-yqwt45eh\n",
      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied (use --upgrade to upgrade): accelerate==0.3.0.dev0 from git+https://github.com/huggingface/accelerate in /usr/local/lib/python3.7/dist-packages\n",
      "Requirement already satisfied: pyaml>=20.4.0 in /usr/local/lib/python3.7/dist-packages (from accelerate==0.3.0.dev0) (20.4.0)\n",
      "Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from accelerate==0.3.0.dev0) (1.8.1+cu101)\n",
      "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from pyaml>=20.4.0->accelerate==0.3.0.dev0) (3.13)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4.0->accelerate==0.3.0.dev0) (3.7.4.3)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch>=1.4.0->accelerate==0.3.0.dev0) (1.19.5)\n",
      "Building wheels for collected packages: accelerate\n",
      "  Building wheel for accelerate (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for accelerate: filename=accelerate-0.3.0.dev0-cp37-none-any.whl size=46950 sha256=a4de17afbd0510cd8a297b5e0b3f8810d1399e97bbe7e571da5bf9af8089394f\n",
      "  Stored in directory: /tmp/pip-ephem-wheel-cache-qxpoyjhd/wheels/a0/29/ed/a3def9dd6716153e8864901c4b2357423c8186ed3fc4f0ab95\n",
      "Successfully built accelerate\n"
     ]
    }
   ],
   "source": [
    "! pip install datasets transformers\n",
    "! pip install accelerate cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "! pip install git+https://github.com/huggingface/accelerate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R1FESu6IiqOV"
   },
   "source": [
    "Here are all the imports we will need for this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4rvtD4_mskaW",
    "outputId": "4adf87e2-2eea-4845-f3fc-af9b239ea95c"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:TPU has started up successfully with version pytorch-1.8\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from accelerate import Accelerator, DistributedType\n",
    "from datasets import load_dataset, load_metric\n",
    "from transformers import (\n",
    "    AdamW,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    get_linear_schedule_with_warmup,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import datasets\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3_Daio_mlZtt"
   },
   "source": [
    "This notebook can run with any model checkpoint on the [model hub](https://huggingface.co/models) that has a version with a classification head. Here we select [`bert-base-cased`](https://huggingface.co/bert-base-cased)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "aUWu0U_plwMD"
   },
   "outputs": [],
   "source": [
    "model_checkpoint = \"bert-base-cased\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AtBUDbklowDi"
   },
   "source": [
    "The next two sections explain how we load and prepare our data for our model. If you are only interested in seeing how 🤗 Accelerate works, feel free to skip them (but make sure to execute all cells!)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e9LqtHw7otJl"
   },
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NKjBrsTGiyf3"
   },
   "source": [
    "To load the dataset, we use the `load_dataset` function from 🤗 Datasets. It will download and cache it (so the download won't happen if we restart the notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YpUUqTXQiwur",
    "outputId": "f53cd1cd-48c6-4f5d-a36c-0e7b2a29d7b7"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:datasets.builder:Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    }
   ],
   "source": [
    "raw_datasets = load_dataset(\"glue\", \"mrpc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P_j3Tnd6J_4d"
   },
   "source": [
    "The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, validation and test sets (with more keys for the mismatched validation and test set in the special case of `mnli`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Be6GMC1PKIts",
    "outputId": "0a5966cf-7f32-49bd-bfb7-92dde0c529c4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 3668\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 408\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 1725\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tPT8uVIrKLxm"
   },
   "source": [
    "To access an actual element, you need to select a split first, then give an index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xl-dxrpbKNdz",
    "outputId": "c3ded223-4687-4d62-919e-2b2c9293f589"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence1': 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n",
       " 'sentence2': 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .'}"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2YAh1I2tKPJU"
   },
   "source": [
    "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 495
    },
    "id": "Yklu4bOuKReC",
    "outputId": "692d395c-9386-4f56-cc06-a6fb2cf64443"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>idx</th>\n",
       "      <th>label</th>\n",
       "      <th>sentence1</th>\n",
       "      <th>sentence2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1302</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>John Logsdon , a board member , said that \" human spaceflight had become a place where dissent was not welcome . \"</td>\n",
       "      <td>In fact , said John Logsdon , a board member and George Washington University expert on space policy , \" human spaceflight had become a place where dissent was not welcome . \"</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2198</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>On Wednesday , the total of National Guard and Reserve members called to active duty worldwide stood at 154,603 .</td>\n",
       "      <td>As of yesterday , the total number of National Guard and Reserve troops called to duty worldwide stood at 154,603 .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3105</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>Economists said that raised the odds of a rate cut at the Fed 's next meeting on June 24-25 .</td>\n",
       "      <td>Economists said those concerns raised the odds that the Fed might reduce rates at its next meeting on June 24-25 .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3694</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>Another piece would be sent to the lab in Virginia for additional testing , Conte said .</td>\n",
       "      <td>Another piece of the suit will be sent to Virginia for more testing , Conte said .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3346</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>BofA said yesterday it had \" policies in place that prohibit late-day trading \" .</td>\n",
       "      <td>Bank of America says its policies prohibit late-day trading , which is illegal .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>894</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>We need a certifiable pay as you go budget by mid-July or schools wont open in September , Strayhorn said .</td>\n",
       "      <td>Texas lawmakers must close a $ 185.9 million budget gap by the middle of July or the schools wont open in September , Comptroller Carole Keeton Strayhorn said Thursday .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>224</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>Mr Abbas said : \" Every day without an agreement is an opportunity that was lost .</td>\n",
       "      <td>His Palestinian counterpart , Mahmoud Abbas , replied that \" every day without an agreement is an opportunity that was lost .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3608</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>Handset market share for the second quarter , it said , is higher than the first quarter .</td>\n",
       "      <td>Nokia 's market share for the second quarter is estimated to be higher than the first quarter , 2003 . \"</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>538</td>\n",
       "      <td>equivalent</td>\n",
       "      <td>He was eventually taken to London 's Hammersmith hospital , where doctors regulated Blair 's heart beat via electric shock .</td>\n",
       "      <td>He was taken to Hammersmith hospital , where doctors regulated Blair 's heart beat via electric shock in a procedure called \" cardio-version \" .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>3391</td>\n",
       "      <td>not_equivalent</td>\n",
       "      <td>The technology-laced Nasdaq Composite Index .IXIC inched down 1 point , or 0.11 percent , to 1,650 .</td>\n",
       "      <td>The broad Standard &amp; Poor 's 500 Index .SPX inched up 3 points , or 0.32 percent , to 970 .</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import datasets\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, datasets.ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "    display(HTML(df.to_html()))\n",
    "\n",
    "show_random_elements(raw_datasets[\"train\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a_FKFTuBKm8J"
   },
   "source": [
    "## Preprocess the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6StNGqH8KqLG"
   },
   "source": [
    "Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer` which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.\n",
    "\n",
    "To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
    "\n",
    "- we get a tokenizer that corresponds to the model architecture we want to use,\n",
    "- we download the vocabulary used when pretraining this specific checkpoint.\n",
    "\n",
    "That vocabulary will be cached, so it's not downloaded again the next time we run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "rPmZCsmdKsAG"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QdUdoLF6K8Vk"
   },
   "source": [
    "By default (unless you pass `use_fast=False` to the call above) it will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library. Those fast tokenizers are available for almost all models, but if you got an error with the previous call, pass that argument to fall back on the slow (pure Python) tokenizer.\n",
    "\n",
    "You can directly call this tokenizer on one sentence or a pair of sentences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yNuR5fCYLCr9",
    "outputId": "d903f051-b4c3-407c-9b54-d7a377675999"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 8667, 117, 1142, 1141, 5650, 106, 102, 1262, 1142, 5650, 2947, 1114, 1122, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I6YulfZrLF7C"
   },
   "source": [
    "Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later), you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LDG_YTsBLSZi"
   },
   "source": [
    "We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. We also need all of our samples to have the same length (we will train on TPU and they need fixed shapes so we won't pad to the maximum length of a batch) which is done with `padding=\"max_length\"`. The `max_length` argument is used both for the truncation and padding (short inputs are padded to that length and long inputs are truncated to it).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "sDZmE_5cLUqE"
   },
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, padding=\"max_length\", max_length=128)\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SdPLbQGWLYAJ"
   },
   "source": [
    "This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "U3Th9LBtLaGb",
    "outputId": "c5938803-b758-4486-f8ff-b5d9a59ca31e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[101, 7277, 2180, 5303, 4806, 1117, 1711, 117, 2292, 1119, 1270, 107, 1103, 7737, 107, 117, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 11336, 6732, 3384, 1106, 1140, 1112, 1178, 107, 1103, 7737, 107, 117, 7277, 2180, 5303, 4806, 1117, 1711, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 10684, 2599, 9717, 1161, 2205, 11288, 1377, 112, 188, 1196, 4147, 1103, 4129, 1106, 19770, 2787, 1107, 1772, 1111, 109, 123, 119, 126, 3775, 119, 102, 10684, 2599, 9717, 1161, 3306, 11288, 1377, 112, 188, 1107, 1876, 1111, 109, 5691, 1495, 1550, 1105, 1962, 1122, 1106, 19770, 2787, 1111, 109, 122, 119, 129, 3775, 1107, 1772, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1220, 1125, 1502, 1126, 16355, 1113, 1103, 4639, 1113, 1340, 1275, 117, 4733, 1103, 6527, 1111, 4688, 117, 1119, 1896, 119, 102, 1212, 1340, 1275, 117, 1103, 2062, 112, 188, 5032, 1125, 1502, 1126, 16355, 1113, 1103, 4639, 117, 4733, 1103, 16454, 1111, 4688, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 5596, 5347, 19297, 14748, 1942, 117, 22515, 1830, 6117, 1127, 1146, 1627, 18748, 117, 1137, 125, 119, 125, 110, 117, 1120, 138, 109, 125, 119, 4376, 117, 1515, 2206, 1383, 170, 1647, 1344, 1104, 138, 109, 125, 119, 4667, 119, 102, 22515, 1830, 6117, 4874, 1406, 18748, 117, 1137, 125, 119, 127, 110, 117, 1106, 1383, 170, 1647, 5134, 1344, 1120, 138, 109, 125, 119, 
4667, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1109, 4482, 3152, 109, 123, 119, 1429, 117, 1137, 1164, 1429, 3029, 117, 1106, 1601, 5286, 1120, 109, 1626, 119, 4062, 1113, 1103, 1203, 1365, 9924, 7855, 119, 102, 153, 2349, 111, 142, 13619, 119, 6117, 4874, 109, 122, 119, 5519, 1137, 129, 3029, 1106, 109, 1626, 119, 5347, 1113, 1103, 1203, 1365, 9924, 7855, 1113, 5286, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}"
      ]
     },
     "execution_count": 11,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenize_function(raw_datasets['train'][:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W_ZAk-uQLdlG"
   },
   "source": [
    "To apply this function on all the sentences (or pairs of sentences) in our dataset, we just use the `map` method of our `dataset` object we created earlier. This will apply the function on all the elements of all the splits in `dataset`, so our training, validation and testing data will be preprocessed in one single command.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 164,
     "referenced_widgets": [
      "54fa72d8df5d471688bf58ce8d7d994d",
      "10cbaf06eb6e4dab836e92e51e543e1a",
      "9c95e32846454e07987c69b53b49aedd",
      "9a9e13893ca748d98077333b15b466ac",
      "f65d00a348454dfbad7ca4d3ffee1064",
      "c1de1bf5955943b78e29e15c47df291c",
      "45dbedeedfb7447ea942b231091bb29f",
      "2ec7de6c4e16478f80e7802ab8e5a2e4",
      "6c23a5fe1ed549329b15431817918f8d",
      "0a53a19926ec44b495fe0108b8a53d5b",
      "f0be5d04d3f1495585ccc6b46619a0bf",
      "1c7ee04d39714ca597fef65a65d8acdf",
      "33e6e73d07dd41cdba8282a0b7b2e17a",
      "c34061fd048741c7947d89bbe5daafa5",
      "6ae30e0fa7224d59ae3819fd89527770",
      "b92465e23aa94cbb92fe0631bc5ba2e0",
      "8dfc01750fb149bb837d8a2c51fc4198",
      "8efcba41f8f049f4bdc30c0416d52ad8",
      "31c3257c5d1d4120a1030c385b47ca45",
      "96f2dce18ed84dd799efb1780c7d81de",
      "480b04299b9f405e8f9a0f621451d8d4",
      "24f7e1b9446b43e8b17b4c54546ba20d",
      "b4757b916a094c76916e64d7841d2245",
      "c00a3e9d092f43e88d1bb9bf1c680bb9"
     ]
    },
    "id": "cphvac5ILezw",
    "outputId": "84b4a0fa-c727-4f73-bff8-0b9e7f2e22b1"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54fa72d8df5d471688bf58ce8d7d994d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c23a5fe1ed549329b15431817918f8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8dfc01750fb149bb837d8a2c51fc4198",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H4FaM4ZGLgad"
   },
   "source": [
    "Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus requires to not use the cache data). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files, you can pass `load_from_cache_file=False` in the call to `map` to not use the cached files and force the preprocessing to be applied again.\n",
    "\n",
    "Note that we passed `batched=True` to encode the texts by batches together. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.\n",
    "\n",
    "Lastly, we remove the columns that our model will not use. We also need to rename the `label` column to `labels` as this is what our model will expect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "vfbSmdoWk5rG"
   },
   "outputs": [],
   "source": [
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hlw6EGsXlIHv"
   },
   "source": [
    "To double-check we only have columns that are accepted as arguments for the model we will instantiate, we can look at them here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "TmqrhXwYlD4l",
    "outputId": "173f8b3f-d1b0-46f0-f072-b886614d55a6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),\n",
       " 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),\n",
       " 'labels': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),\n",
       " 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}"
      ]
     },
     "execution_count": 14,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_datasets[\"train\"].features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s_2D3L6kn7fn"
   },
   "source": [
    "The model we will be using is a `BertModelForSequenceClassification`. We can check its signature in the [Transformers documentation](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertForSequenceClassification) and all seems to be right! The last step is to set our datasets in the `\"torch\"` format, so that each item in it is now a dictionary with tensor values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "Qy1VJb2dsJus"
   },
   "outputs": [],
   "source": [
    "tokenized_datasets.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CSoNVxwUo_jt"
   },
   "source": [
    "## A first look at the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e6dcSOA_pzPg"
   },
   "source": [
    "Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about sentence classification, we use the `AutoModelForSequenceClassification` class. Like with the tokenizer, the from_pretrained method will download and cache the model for us. The only thing we have to specify is the number of labels for our problem (which is 2 here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tlYXqxvtp-8Y",
    "outputId": "6c39c938-a3b3-423c-bc32-2a7c93cc120d"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lZpkm-f_qGya"
   },
   "source": [
    "The warning is telling us we are throwing away some weights (the vocab_transform and vocab_layer_norm layers) and randomly initializing some other (the pre_classifier and classifier layers). This is absolutely normal in this case, because we are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.\n",
    "\n",
    "Note that we will are only creating the model here to look at it and debug problems. We will create the model we will train inside our training function: to train on TPU in colab, we have to create a big training function that will be executed on each code of the TPU. It's fine to do use the datasets defined before (they will be copied to each TPU core) but the model itself will need to be re-instantiated and placed on each device for it to work.\n",
    "\n",
    "Now to get the data we need to define our training and evaluation dataloaders. Again, we only create them here for debugging purposes, they will be re-instantiated in our training function, which is why we define a function that builds them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "Ll4PjC5nqFqk"
   },
   "outputs": [],
   "source": [
    "def create_dataloaders(train_batch_size=8, eval_batch_size=32):\n",
    "    train_dataloader = DataLoader(\n",
    "        tokenized_datasets[\"train\"], shuffle=True, batch_size=train_batch_size\n",
    "    )\n",
    "    eval_dataloader = DataLoader(\n",
    "        tokenized_datasets[\"validation\"], shuffle=False, batch_size=eval_batch_size\n",
    "    )\n",
    "    return train_dataloader, eval_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4a_rrVs0siR7"
   },
   "source": [
    "Let's have a look at our train and evaluation dataloaders to check a batch can go through the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "l2krwQzMrdK2"
   },
   "outputs": [],
   "source": [
    "train_dataloader, eval_dataloader = create_dataloaders()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dwrNQ1Cbsp5i"
   },
   "source": [
    "We just loop through one batch. Since our datasets elements are dictionaries of tensors, it's the same for our batch and we can have a quick look at all the shapes. Note that this cell takes a bit of time to execute since we run a batch of our data through the model on the CPU (if you changed the checkpoint to a bigger model, it might take too much time so comment it out)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RA-_yUyrritM",
    "outputId": "2717efe0-5081-42af-93a3-3bdd88334a25"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'attention_mask': torch.Size([8, 128]), 'input_ids': torch.Size([8, 128]), 'labels': torch.Size([8]), 'token_type_ids': torch.Size([8, 128])}\n"
     ]
    }
   ],
   "source": [
    "for batch in train_dataloader:\n",
    "    print({k: v.shape for k, v in batch.items()})\n",
    "    outputs = model(**batch)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fVCmBAxbtCU_"
   },
   "source": [
    "The output of our model is a `SequenceClassifierOutput`, with the `loss` (since we provided labels) and `logits` (of shape 8, our batch size, by 2, the number of labels)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bccUGpZbsXdS",
    "outputId": "f79e2882-3813-4340-c2f5-703bc196743a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifierOutput([('loss', tensor(0.5652, grad_fn=<NllLossBackward>)),\n",
       "                          ('logits', tensor([[-0.5161, -0.1519],\n",
       "                                   [-0.5209, -0.1433],\n",
       "                                   [-0.5311, -0.1388],\n",
       "                                   [-0.5313, -0.1394],\n",
       "                                   [-0.5311, -0.1345],\n",
       "                                   [-0.5342, -0.1251],\n",
       "                                   [-0.5255, -0.1308],\n",
       "                                   [-0.5352, -0.1255]], grad_fn=<AddmmBackward>))])"
      ]
     },
     "execution_count": 20,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-9hiuBqMtU5W"
   },
   "source": [
    "The last piece we will need for the model evaluation is the metric. The `datasets` library provides a function `load_metric` that allows us to easily create a `datasets.Metric` object we can use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "brJ6K4oBtpST"
   },
   "outputs": [],
   "source": [
    "metric = load_metric(\"glue\", \"mrpc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MG2tjiSEttNO"
   },
   "source": [
    "To use this object on some predictions we call the `compute` methode to get our metric results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YVAPgl1BuB3s",
    "outputId": "c92e4117-8111-4006-f2b6-27e9765ed998"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.875, 'f1': 0.9333333333333333}"
      ]
     },
     "execution_count": 22,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = outputs.logits.detach().argmax(dim=-1)\n",
    "metric.compute(predictions=predictions, references=batch[\"labels\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rq1_mw0mucQV"
   },
   "source": [
    "Unsurpringly, our model with its random head does not perform well, which is why we need to fine-tune it!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "akoJx8puvTfV"
   },
   "source": [
    "## Fine-tuning the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "93vWWJa2wVEj"
   },
   "source": [
    "We are now ready to fine-tune this model on our dataset. As mentioned before, everything related to training needs to be in one big training function that will be executed on each TPU core, thanks to our `notebook_launcher`.\n",
    "\n",
    "It will use this dictionary of hyperparameters, so tweak anything you like in here!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "rqyYx8SK8e0O"
   },
   "outputs": [],
   "source": [
    "hyperparameters = {\n",
    "    \"learning_rate\": 2e-5,\n",
    "    \"num_epochs\": 3,\n",
    "    \"train_batch_size\": 8, # Actual batch size will this x 8\n",
    "    \"eval_batch_size\": 32, # Actual batch size will this x 8\n",
    "    \"seed\": 42,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZR1F1G_n863W"
   },
   "source": [
    "The most important thing to remember for training on TPUs is that your `accelerator` object as to be defined inside the training function. If you define it in another cell that gets executed before the final launch (for debugging), you will need to restart your notebook as the line `accelerator = Accelerator()` needs to be executed for the first time inside the training function spwaned on each TPU core.\n",
    "\n",
    "This is because that line will look for a TPU device, and if you set it outside of the distributed training launched by `notebook_launcher`, it will perform setup that cannot be undone in your runtime and you will only have access to one TPU core until you restart the notebook.\n",
    "\n",
    "Since we can't explore each piece in separate cells, comments have been left in the code. This is all pretty standard and you will notice how little the code changes from a regular training loop! The main lines added are:\n",
    "\n",
    "- `accelerator = Accelerator()` to initalize the distributed setup,\n",
    "- sending all objects to `accelerator.prepare`,\n",
    "- replace `loss.backward()` with `accelerator.backward(loss)`,\n",
    "- use `accelerator.gather` to gather all predictions and labels before storing them in our list of predictions/labels,\n",
    "- truncate predictions and labels as the prepared evaluation dataloader has a few more samples to make batches of the same size on each process.\n",
    "\n",
    "The first three are for distributed training, the last two for distributed evaluation. If you don't care about distributed evaluation, you can also just replace that part by your standard evaluation loop launched on the main process only.\n",
    "\n",
    "Other changes (which are purely cosmetic to make the output of the training readable) are:\n",
    "\n",
    "- some logging behavior behind a `if accelerator.is_main_process:`,\n",
    "- disable the progress bar if `accelerator.is_main_process` is `False`,\n",
    "- use `accelerator.print` instead of `print`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "id": "QueNpNJF9eGq"
   },
   "outputs": [],
   "source": [
    "def training_function():\n",
    "    # Initialize accelerator\n",
    "    accelerator = Accelerator()\n",
    "\n",
    "    # To have only one message (and not 8) per logs of Transformers or Datasets, we set the logging verbosity\n",
    "    # to INFO for the main process only.\n",
    "    if accelerator.is_main_process:\n",
    "        datasets.utils.logging.set_verbosity_warning()\n",
    "        transformers.utils.logging.set_verbosity_info()\n",
    "    else:\n",
    "        datasets.utils.logging.set_verbosity_error()\n",
    "        transformers.utils.logging.set_verbosity_error()\n",
    "\n",
    "    train_dataloader, eval_dataloader = create_dataloaders(\n",
    "        train_batch_size=hyperparameters[\"train_batch_size\"], eval_batch_size=hyperparameters[\"eval_batch_size\"]\n",
    "    )\n",
    "    # The seed need to be set before we instantiate the model, as it will determine the random head.\n",
    "    set_seed(hyperparameters[\"seed\"])\n",
    "\n",
    "    # Instantiate the model, let Accelerate handle the device placement.\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)\n",
    "\n",
    "    # Instantiate optimizer\n",
    "    optimizer = AdamW(params=model.parameters(), lr=hyperparameters[\"learning_rate\"])\n",
    "\n",
    "    # Prepare everything\n",
    "    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n",
    "    # prepare method.\n",
    "    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "        model, optimizer, train_dataloader, eval_dataloader\n",
    "    )\n",
    "\n",
    "    num_epochs = hyperparameters[\"num_epochs\"]\n",
    "    # Instantiate learning rate scheduler after preparing the training dataloader as the prepare method\n",
    "    # may change its length.\n",
    "    lr_scheduler = get_linear_schedule_with_warmup(\n",
    "        optimizer=optimizer,\n",
    "        num_warmup_steps=100,\n",
    "        num_training_steps=len(train_dataloader) * num_epochs,\n",
    "    )\n",
    "\n",
    "    # Instantiate a progress bar to keep track of training. Note that we only enable it on the main\n",
    "    # process to avoid having 8 progress bars.\n",
    "    progress_bar = tqdm(range(num_epochs * len(train_dataloader)), disable=not accelerator.is_main_process)\n",
    "    # Now we train the model\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        for step, batch in enumerate(train_dataloader):\n",
    "            outputs = model(**batch)\n",
    "            loss = outputs.loss\n",
    "            accelerator.backward(loss)\n",
    "            \n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "            progress_bar.update(1)\n",
    "\n",
    "        model.eval()\n",
    "        all_predictions = []\n",
    "        all_labels = []\n",
    "\n",
    "        for step, batch in enumerate(eval_dataloader):\n",
    "            with torch.no_grad():\n",
    "                outputs = model(**batch)\n",
    "            predictions = outputs.logits.argmax(dim=-1)\n",
    "\n",
    "            # We gather predictions and labels from the 8 TPUs to have them all.\n",
    "            all_predictions.append(accelerator.gather(predictions))\n",
    "            all_labels.append(accelerator.gather(batch[\"labels\"]))\n",
    "\n",
    "        # Concatenate all predictions and labels.\n",
    "        # The last thing we need to do is to truncate the predictions and labels we concatenated\n",
    "        # together as the prepared evaluation dataloader has a little bit more elements to make\n",
    "        # batches of the same size on each process.\n",
    "        all_predictions = torch.cat(all_predictions)[:len(tokenized_datasets[\"validation\"])]\n",
    "        all_labels = torch.cat(all_labels)[:len(tokenized_datasets[\"validation\"])]\n",
    "\n",
    "        eval_metric = metric.compute(predictions=all_predictions, references=all_labels)\n",
    "\n",
    "        # Use accelerator.print to print only on the main process.\n",
    "        accelerator.print(f\"epoch {epoch}:\", eval_metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K6LUeuUlEOvL"
   },
   "source": [
    "And we're ready for launch! It's super easy with the `notebook_launcher` from the Accelerate library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 647,
     "referenced_widgets": [
      "a91fb54c8a3d48d6a094c2a83f34c432",
      "55b8e8cb8646431eb67443739529200f",
      "07396159436a4202abf23156c1c5a0c6",
      "bf64d0787a1b4b5eac8a2db8f3bc69d6",
      "7b09dd3fb5e045428091b12bd82877c8",
      "ae0fc7ad41264c94898e22090ca22422",
      "fe4dd099be15498ea29b1013db02b503",
      "f9d76b1743c643939f5a59a30d378a2f"
     ]
    },
    "id": "ZBZdzN3lOEuL",
    "outputId": "340c815e-c6f0-47f9-e83e-87f98788c2c8"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file https://huggingface.co/bert-base-cased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/a803e0468a8fe090683bdc453f4fac622804f49de86d7cecaee92365d4a0f829.a64a22196690e0e82ead56f388a3ef3a50de93335926ccfa20610217db589307\n",
      "Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.5.1\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 28996\n",
      "}\n",
      "\n",
      "loading weights file https://huggingface.co/bert-base-cased/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/092cc582560fc3833e556b3f833695c26343cb54b7e88cd02d40821462a74999.1f48cab6c959fc6c360d22bea39d06959e90f5b002e77e836d2da45464875cda\n",
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a91fb54c8a3d48d6a094c2a83f34c432",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=174.0), HTML(value='')))"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0: {'accuracy': 0.7303921568627451, 'f1': 0.8286604361370716}\n",
      "epoch 1: {'accuracy': 0.7892156862745098, 'f1': 0.8621794871794871}\n",
      "epoch 2: {'accuracy': 0.821078431372549, 'f1': 0.8726003490401396}\n"
     ]
    }
   ],
   "source": [
    "from accelerate import notebook_launcher\n",
    "\n",
    "notebook_launcher(training_function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "id": "Qaa-dIlEvCwY"
   },
   "outputs": [],
   "source": [
    ""
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Simple NLP Example",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "07396159436a4202abf23156c1c5a0c6": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ae0fc7ad41264c94898e22090ca22422",
      "max": 174,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7b09dd3fb5e045428091b12bd82877c8",
      "value": 174
     }
    },
    "0a53a19926ec44b495fe0108b8a53d5b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "10cbaf06eb6e4dab836e92e51e543e1a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "1c7ee04d39714ca597fef65a65d8acdf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b92465e23aa94cbb92fe0631bc5ba2e0",
      "placeholder": "​",
      "style": "IPY_MODEL_6ae30e0fa7224d59ae3819fd89527770",
      "value": " 1/1 [00:00&lt;00:00,  2.72ba/s]"
     }
    },
    "24f7e1b9446b43e8b17b4c54546ba20d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2ec7de6c4e16478f80e7802ab8e5a2e4": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "31c3257c5d1d4120a1030c385b47ca45": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_24f7e1b9446b43e8b17b4c54546ba20d",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_480b04299b9f405e8f9a0f621451d8d4",
      "value": 2
     }
    },
    "33e6e73d07dd41cdba8282a0b7b2e17a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "45dbedeedfb7447ea942b231091bb29f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "480b04299b9f405e8f9a0f621451d8d4": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "54fa72d8df5d471688bf58ce8d7d994d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_9c95e32846454e07987c69b53b49aedd",
       "IPY_MODEL_9a9e13893ca748d98077333b15b466ac"
      ],
      "layout": "IPY_MODEL_10cbaf06eb6e4dab836e92e51e543e1a"
     }
    },
    "55b8e8cb8646431eb67443739529200f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "6ae30e0fa7224d59ae3819fd89527770": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6c23a5fe1ed549329b15431817918f8d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_f0be5d04d3f1495585ccc6b46619a0bf",
       "IPY_MODEL_1c7ee04d39714ca597fef65a65d8acdf"
      ],
      "layout": "IPY_MODEL_0a53a19926ec44b495fe0108b8a53d5b"
     }
    },
    "7b09dd3fb5e045428091b12bd82877c8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "8dfc01750fb149bb837d8a2c51fc4198": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_31c3257c5d1d4120a1030c385b47ca45",
       "IPY_MODEL_96f2dce18ed84dd799efb1780c7d81de"
      ],
      "layout": "IPY_MODEL_8efcba41f8f049f4bdc30c0416d52ad8"
     }
    },
    "8efcba41f8f049f4bdc30c0416d52ad8": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "96f2dce18ed84dd799efb1780c7d81de": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c00a3e9d092f43e88d1bb9bf1c680bb9",
      "placeholder": "​",
      "style": "IPY_MODEL_b4757b916a094c76916e64d7841d2245",
      "value": " 2/2 [00:09&lt;00:00,  4.66s/ba]"
     }
    },
    "9a9e13893ca748d98077333b15b466ac": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2ec7de6c4e16478f80e7802ab8e5a2e4",
      "placeholder": "​",
      "style": "IPY_MODEL_45dbedeedfb7447ea942b231091bb29f",
      "value": " 4/4 [00:10&lt;00:00,  2.56s/ba]"
     }
    },
    "9c95e32846454e07987c69b53b49aedd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c1de1bf5955943b78e29e15c47df291c",
      "max": 4,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f65d00a348454dfbad7ca4d3ffee1064",
      "value": 4
     }
    },
    "a91fb54c8a3d48d6a094c2a83f34c432": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_07396159436a4202abf23156c1c5a0c6",
       "IPY_MODEL_bf64d0787a1b4b5eac8a2db8f3bc69d6"
      ],
      "layout": "IPY_MODEL_55b8e8cb8646431eb67443739529200f"
     }
    },
    "ae0fc7ad41264c94898e22090ca22422": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b4757b916a094c76916e64d7841d2245": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b92465e23aa94cbb92fe0631bc5ba2e0": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bf64d0787a1b4b5eac8a2db8f3bc69d6": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f9d76b1743c643939f5a59a30d378a2f",
      "placeholder": "​",
      "style": "IPY_MODEL_fe4dd099be15498ea29b1013db02b503",
      "value": " 174/174 [02:28&lt;00:00,  1.47it/s]"
     }
    },
    "c00a3e9d092f43e88d1bb9bf1c680bb9": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c1de1bf5955943b78e29e15c47df291c": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c34061fd048741c7947d89bbe5daafa5": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f0be5d04d3f1495585ccc6b46619a0bf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c34061fd048741c7947d89bbe5daafa5",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_33e6e73d07dd41cdba8282a0b7b2e17a",
      "value": 1
     }
    },
    "f65d00a348454dfbad7ca4d3ffee1064": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "f9d76b1743c643939f5a59a30d378a2f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fe4dd099be15498ea29b1013db02b503": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
